{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Action 1: Scheduling Santa's Workshop Visits\n",
    "\n",
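    "In the 100 days before Christmas, Santa opens his workshop to visits by families. How should these visits be scheduled?\n",
    "\n",
    "- Each family lists 10 preferred days, `choice_0` to `choice_9`. Each value is the number of days before Christmas; for example, 1 means December 24.\n",
    "- Every family must be scheduled exactly once.\n",
    "- There are 5000 families, i.e. `family_id` ranges over [0, 4999].\n",
    "- The number of visitors per day must be between 125 and 300.\n",
    "\n",
    "Two metrics measure the quality of a schedule:\n",
    "\n",
    "1. **Preference cost**: how well Santa accommodates each family's preferences.\n",
    "2. **Accounting penalty**: Santa's financial cost. When the number of visitors $N(d)$ on the day $d$ days before Christmas exceeds 125, the workshop gets crowded and incurs extra cleaning costs:\n",
    "\n",
    "$$\\text{accounting penalty} = \\sum_{d=1}^{100} \\frac{N(d) - 125}{400} N(d)^{\\frac{1}{2} + \\frac{|N(d) - N(d+1)|}{50}}, \\quad N(101) = N(100)$$\n",
    "\n",
    "The final score is **Score = preference cost + accounting penalty**, and the schedule of every family is submitted as `submission.csv`.\n",
    "\n",
    "Tasks:\n",
    "\n",
    "1. Use LP to schedule most of the families (30 points).\n",
    "2. Use MIP to schedule the remaining families (30 points).\n",
    "3. Optimize the resulting solution (20 points)."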
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>choice_0</th>\n",
       "      <th>choice_1</th>\n",
       "      <th>choice_2</th>\n",
       "      <th>choice_3</th>\n",
       "      <th>choice_4</th>\n",
       "      <th>choice_5</th>\n",
       "      <th>choice_6</th>\n",
       "      <th>choice_7</th>\n",
       "      <th>choice_8</th>\n",
       "      <th>choice_9</th>\n",
       "      <th>n_people</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>family_id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>52</td>\n",
       "      <td>38</td>\n",
       "      <td>12</td>\n",
       "      <td>82</td>\n",
       "      <td>33</td>\n",
       "      <td>75</td>\n",
       "      <td>64</td>\n",
       "      <td>76</td>\n",
       "      <td>10</td>\n",
       "      <td>28</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>26</td>\n",
       "      <td>4</td>\n",
       "      <td>82</td>\n",
       "      <td>5</td>\n",
       "      <td>11</td>\n",
       "      <td>47</td>\n",
       "      <td>38</td>\n",
       "      <td>6</td>\n",
       "      <td>66</td>\n",
       "      <td>61</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>100</td>\n",
       "      <td>54</td>\n",
       "      <td>25</td>\n",
       "      <td>12</td>\n",
       "      <td>27</td>\n",
       "      <td>82</td>\n",
       "      <td>10</td>\n",
       "      <td>89</td>\n",
       "      <td>80</td>\n",
       "      <td>33</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>95</td>\n",
       "      <td>1</td>\n",
       "      <td>96</td>\n",
       "      <td>32</td>\n",
       "      <td>6</td>\n",
       "      <td>40</td>\n",
       "      <td>31</td>\n",
       "      <td>9</td>\n",
       "      <td>59</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>53</td>\n",
       "      <td>1</td>\n",
       "      <td>47</td>\n",
       "      <td>93</td>\n",
       "      <td>26</td>\n",
       "      <td>3</td>\n",
       "      <td>46</td>\n",
       "      <td>16</td>\n",
       "      <td>42</td>\n",
       "      <td>39</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           choice_0  choice_1  choice_2  choice_3  choice_4  choice_5  \\\n",
       "family_id                                                               \n",
       "0                52        38        12        82        33        75   \n",
       "1                26         4        82         5        11        47   \n",
       "2               100        54        25        12        27        82   \n",
       "3                 2        95         1        96        32         6   \n",
       "4                53         1        47        93        26         3   \n",
       "\n",
       "           choice_6  choice_7  choice_8  choice_9  n_people  \n",
       "family_id                                                    \n",
       "0                64        76        10        28         4  \n",
       "1                38         6        66        61         4  \n",
       "2                10        89        80        33         3  \n",
       "3                40        31         9        59         2  \n",
       "4                46        16        42        39         4  "
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "# Load the family data, indexed by family_id\n",
    "data = pd.read_csv('./family_data.csv', index_col='family_id')\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
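    "## Step 2: Preprocessing\n",
    "1) Compute the preference cost matrix `pcost_mat`\n",
    "\n",
    "2) Compute the accounting cost matrix `acost_mat`\n",
    "\n",
    "3) Get the number of people in each family, `FAMILY_SIZE`\n",
    "\n",
    "4) Get each family's preferred days (the `choice_` columns), `DESIRED`\n"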
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preference cost for a family of n people that is given its choice-th preference\n",
    "# (choice 0-9 is the rank among choice_0..choice_9; choice > 9 means none of its choices)\n",
    "def get_penalty(n, choice):\n",
    "    if choice == 0:\n",
    "        penalty = 0\n",
    "    elif choice == 1:\n",
    "        penalty = 50\n",
    "    elif choice == 2:\n",
    "        penalty = 50 + 9 * n\n",
    "    elif choice == 3:\n",
    "        penalty = 100 + 9 * n\n",
    "    elif choice == 4:\n",
    "        penalty = 200 + 9 * n\n",
    "    elif choice == 5:\n",
    "        penalty = 200 + 18 * n\n",
    "    elif choice == 6:\n",
    "        penalty = 300 + 18 * n\n",
    "    elif choice == 7:\n",
    "        penalty = 300 + 36 * n\n",
    "    elif choice == 8:\n",
    "        penalty = 400 + 36 * n\n",
    "    elif choice == 9:\n",
    "        penalty = 500 + (36 + 199) * n\n",
    "    else:\n",
    "        penalty = 500 + (36 + 398) * n\n",
    "    return penalty\n"
   ]
  },
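  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tiers in `get_penalty` can also be read as (fixed cost, cost per person) pairs. The cell below is a hypothetical table-driven equivalent, assuming exactly the same tiers as `get_penalty`; keeping the constants in one list makes them easier to check against the rules. For example, a family of 4 that gets none of its choices pays 500 + 434 × 4 = 2236, the value that appears in the `pcost_mat` output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical table-driven equivalent of get_penalty:\n",
    "# PENALTY_TABLE[choice] = (fixed cost, cost per family member)\n",
    "PENALTY_TABLE = [(0, 0), (50, 0), (50, 9), (100, 9), (200, 9),\n",
    "                 (200, 18), (300, 18), (300, 36), (400, 36), (500, 36 + 199)]\n",
    "OTHERWISE = (500, 36 + 398)  # the family gets none of its 10 choices\n",
    "\n",
    "def penalty_from_table(n, choice):\n",
    "    fixed, per_person = PENALTY_TABLE[choice] if 0 <= choice <= 9 else OTHERWISE\n",
    "    return fixed + per_person * n\n",
    "\n",
    "print(penalty_from_table(4, 10))  # 2236"
   ]
  },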
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2236, 2236, 2236, ..., 2236, 2236, 2236],\n",
       "       [2236, 2236, 2236, ..., 2236, 2236, 2236],\n",
       "       [1802, 1802, 1802, ..., 1802, 1802,    0],\n",
       "       ...,\n",
       "       [3104, 3104,  616, ..., 3104, 3104, 3104],\n",
       "       [ 390, 2670, 2670, ..., 2670, 2670, 2670],\n",
       "       [2236, 2236, 2236, ..., 2236, 2236, 2236]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "N_DAYS = 100  # number of days that can be scheduled\n",
    "N_FAMILY = 5000  # number of families (family_id 0-4999)\n",
    "MIN_OCCUPANCY = 125  # minimum number of visitors per day\n",
    "MAX_OCCUPANCY = 300  # maximum number of visitors per day\n",
    "\n",
    "import numpy as np\n",
    "# pcost_mat[f, d]: preference cost if family f visits on day index d (0-99),\n",
    "# i.e. d+1 days before Christmas. Shape (5000, 100)\n",
    "pcost_mat = np.full(shape=(N_FAMILY, N_DAYS), fill_value=999999)\n",
    "for f in range(N_FAMILY):\n",
    "    # number of people in the family\n",
    "    f_num = data.loc[f, 'n_people']\n",
    "    # default: the 'otherwise' penalty for days that are not among its choices\n",
    "    pcost_mat[f, :] = get_penalty(f_num, 10)\n",
    "    # overwrite the penalties of choice_0 to choice_9\n",
    "    for choice in range(10):\n",
    "        temp = data.loc[f, f'choice_{choice}']  # days before Christmas (1-100)\n",
    "        penalty = get_penalty(f_num, choice)\n",
    "        pcost_mat[f, temp - 1] = penalty\n",
    "pcost_mat  # 0 means the family got its first choice\n",
    "# For a family of 4, the 'otherwise' penalty is 500 + 434 * 4 = 2236"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5000, 100)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pcost_mat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
       "        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
       "       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
       "        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
       "       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
       "        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],\n",
       "       ...,\n",
       "       [4.16316072e+15, 3.71482922e+15, 3.31477861e+15, ...,\n",
       "        7.46610759e+00, 8.36716954e+00, 9.37697794e+00],\n",
       "       [4.79555148e+15, 4.27883100e+15, 3.81778713e+15, ...,\n",
       "        8.43020770e+00, 7.52185316e+00, 8.43020770e+00],\n",
       "       [5.52415954e+15, 4.92860244e+15, 4.39725208e+15, ...,\n",
       "        9.51970597e+00, 8.49339085e+00, 7.57772228e+00]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accounting penalty lookup table: acost_mat[i, j] for i visitors today and j visitors the previous day\n",
    "# Shape (MAX_OCCUPANCY+1, MAX_OCCUPANCY+1): the +1 makes index MAX_OCCUPANCY valid\n",
    "acost_mat = np.zeros(shape=(MAX_OCCUPANCY+1, MAX_OCCUPANCY+1), dtype=np.float64)\n",
    "for i in range(acost_mat.shape[0]):  # visitors today\n",
    "    for j in range(acost_mat.shape[1]):  # visitors the previous day\n",
    "        diff = abs(i - j)\n",
    "        acost_mat[i, j] = max(0, (i - 125) / 400 * i ** (0.5 + diff / 50.0))\n",
    "acost_mat"
   ]
  },
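  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on `acost_mat`, the accounting penalty can also be evaluated directly from a vector of daily occupancies. A minimal sketch, assuming `occupancy[d]` holds the visitors on day index `d` (d+1 days before Christmas) and that the earliest day is compared with itself; the helper name is hypothetical. Unlike `acost_mat` it does not clip at 0, which makes no difference when every day has at least 125 visitors. With 300 visitors on two consecutive days, each day contributes (300 - 125)/400 × 300^0.5 ≈ 7.5777, matching `acost_mat[300, 300]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def accounting_penalty(occupancy):\n",
    "    # Hypothetical helper: occupancy[d] = visitors on day index d (d+1 days before Christmas)\n",
    "    occ = np.asarray(occupancy, dtype=float)\n",
    "    # The chronologically previous day of index d is index d+1;\n",
    "    # the earliest day (last index) is compared with itself\n",
    "    prev = np.append(occ[1:], occ[-1])\n",
    "    diff = np.abs(occ - prev)\n",
    "    return float(np.sum((occ - 125.0) / 400.0 * occ ** (0.5 + diff / 50.0)))\n",
    "\n",
    "print(round(accounting_penalty([300, 300]), 4))  # about 2 x 7.5777\n",
    "print(accounting_penalty([125, 125, 125]))  # 0.0 at the minimum occupancy"
   ]
  },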
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([4, 4, 3, ..., 6, 5, 4], dtype=int64)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "FAMILY_SIZE = data['n_people'].values\n",
    "FAMILY_SIZE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[51, 37, 11, ..., 75,  9, 27],\n",
       "       [25,  3, 81, ...,  5, 65, 60],\n",
       "       [99, 53, 24, ..., 88, 79, 32],\n",
       "       ...,\n",
       "       [31, 65, 53, ..., 80,  2,  6],\n",
       "       [66, 91,  3, ..., 11, 25, 69],\n",
       "       [12, 10, 24, ..., 38, 17, 46]], dtype=int64)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# DESIRED[f, k]: day index of family f's choice_k\n",
    "# Days before Christmas 1-100 map to pcost_mat column indices 0-99\n",
    "DESIRED = data.values[:, :-1] - 1\n",
    "DESIRED"
   ]
  },
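  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `pcost_mat`, `acost_mat` and `FAMILY_SIZE` in place, a complete schedule can be scored. The sketch below assumes `assignment[f]` is the day index (0-99) of family `f`; the function name and signature are hypothetical, and schedules that violate the occupancy bounds score infinity. Once every family has a day it could be called as `total_score(days, FAMILY_SIZE, pcost_mat, acost_mat)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def total_score(assignment, family_size, pcost, acost,\n",
    "                n_days=100, min_occ=125, max_occ=300):\n",
    "    # Hypothetical helper: Score = preference cost + accounting penalty\n",
    "    assignment = np.asarray(assignment)\n",
    "    family_size = np.asarray(family_size)\n",
    "    # Preference cost: each family's cost for its assigned day index\n",
    "    pref = pcost[np.arange(len(assignment)), assignment].sum()\n",
    "    # Daily occupancy per day index 0..n_days-1\n",
    "    occ = np.bincount(assignment, weights=family_size, minlength=n_days).astype(int)\n",
    "    if occ.min() < min_occ or occ.max() > max_occ:\n",
    "        return float('inf')  # hard occupancy bounds violated\n",
    "    # Pair each day with the chronologically previous day (index d+1);\n",
    "    # the earliest day is paired with itself\n",
    "    prev = np.append(occ[1:], occ[-1])\n",
    "    return float(pref + acost[occ, prev].sum())"
   ]
  },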
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
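    "## Step 3: Solve the schedule with LP and MIP\n",
    "\n",
    "1) First use LP to schedule the vast majority of families\n",
    "\n",
    "2) Then use MIP to schedule the remaining families\n",
    "\n",
    "3) Combine both results into the final schedule\n"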
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ortools.linear_solver import pywraplp\n",
    "# LP relaxation; the objective is the preference cost only\n",
    "def solveLP():\n",
    "    # GLOP is a pure LP solver, so x[i, j] is continuous in [0, 1]\n",
    "    solver = pywraplp.Solver('AssignmentProblem', pywraplp.Solver.GLOP_LINEAR_PROGRAMMING)\n",
    "    x = {}  # x[i, j]: (fraction of) family i visiting on day index j\n",
    "    # candidates[j]: families that list day j among their choices\n",
    "    candidates = [[] for _ in range(N_DAYS)]\n",
    "    for i in range(N_FAMILY):  # family_id\n",
    "        for j in DESIRED[i, :]:  # day indices of family i's choices\n",
    "            candidates[j].append(i)\n",
    "            # decision variable, only created for the family's 10 choices\n",
    "            x[i, j] = solver.NumVar(0.0, 1.0, 'x[%i,%i]' % (i, j))\n",
    "    \n",
    "    # Number of visitors on each day index j (0-99)\n",
    "    daily_occupancy = [solver.Sum([x[i, j] * FAMILY_SIZE[i] for i in candidates[j]]) for j in range(N_DAYS)]\n",
    "    \n",
    "    # How many times each family is scheduled across its 10 choices\n",
    "    family_presence = [solver.Sum(x[i, j] for j in DESIRED[i, :]) for i in range(N_FAMILY)]\n",
    "    \n",
    "    # Objective: minimize the preference cost\n",
    "    preference_cost = solver.Sum([pcost_mat[i, j] * x[i, j] for i in range(N_FAMILY) for j in DESIRED[i, :]])\n",
    "    solver.Minimize(preference_cost)\n",
    "    \n",
    "    # Extra constraint we impose to keep the accounting penalty low:\n",
    "    # occupancy may change by at most 25 between adjacent days\n",
    "    for j in range(N_DAYS - 1):\n",
    "        solver.Add(daily_occupancy[j] - daily_occupancy[j+1] <= 25)\n",
    "        solver.Add(daily_occupancy[j+1] - daily_occupancy[j] <= 25)\n",
    "    \n",
    "    # Each family is scheduled exactly once among its 10 choices\n",
    "    for i in range(N_FAMILY):\n",
    "        solver.Add(family_presence[i] == 1)\n",
    "    # Daily visitor bounds\n",
    "    for j in range(N_DAYS):\n",
    "        solver.Add(daily_occupancy[j] >= MIN_OCCUPANCY)\n",
    "        solver.Add(daily_occupancy[j] <= MAX_OCCUPANCY)\n",
    "    status = solver.Solve()\n",
    "    assert status == pywraplp.Solver.OPTIMAL, 'LP did not reach an optimal solution'\n",
    "    \n",
    "    # Keep every (family, day) pair with a nonzero value\n",
    "    temp = [(i, j, x[i, j].solution_value()) for i in range(N_FAMILY) for j in DESIRED[i, :] if x[i, j].solution_value() > 0]\n",
    "    \n",
    "    print(solver.Objective().Value())\n",
    "    # Visit-day assignment (possibly fractional)\n",
    "    df = pd.DataFrame(temp, columns=['family_id', 'day', 'result'])\n",
    "    return df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "73702.31696428571\n",
      "Wall time: 13.6 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "result = solveLP()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>family_id</th>\n",
       "      <th>day</th>\n",
       "      <th>result</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>51</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>25</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>99</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>52</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5068</th>\n",
       "      <td>4995</td>\n",
       "      <td>15</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5069</th>\n",
       "      <td>4996</td>\n",
       "      <td>87</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5070</th>\n",
       "      <td>4997</td>\n",
       "      <td>31</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5071</th>\n",
       "      <td>4998</td>\n",
       "      <td>91</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5072</th>\n",
       "      <td>4999</td>\n",
       "      <td>12</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4931 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      family_id  day  result\n",
       "0             0   51     1.0\n",
       "1             1   25     1.0\n",
       "2             2   99     1.0\n",
       "3             3    1     1.0\n",
       "4             4   52     1.0\n",
       "...         ...  ...     ...\n",
       "5068       4995   15     1.0\n",
       "5069       4996   87     1.0\n",
       "5070       4997   31     1.0\n",
       "5071       4998   91     1.0\n",
       "5072       4999   12     1.0\n",
       "\n",
       "[4931 rows x 3 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
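    "# Some families are split fractionally across several days, so `result` has more rows than there are families\n",
    "# Threshold above which an LP value is treated as 1\n",
    "THRS = 0.999\n",
    "# Families assigned integrally by the LP (.copy() avoids a SettingWithCopyWarning later)\n",
    "assigned_df = result[result.result > THRS].copy()\n",
    "assigned_df"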
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>family_id</th>\n",
       "      <th>day</th>\n",
       "      <th>result</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>59</th>\n",
       "      <td>59</td>\n",
       "      <td>38</td>\n",
       "      <td>0.25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>60</th>\n",
       "      <td>59</td>\n",
       "      <td>14</td>\n",
       "      <td>0.75</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>241</th>\n",
       "      <td>240</td>\n",
       "      <td>32</td>\n",
       "      <td>0.75</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>242</th>\n",
       "      <td>240</td>\n",
       "      <td>56</td>\n",
       "      <td>0.25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>264</th>\n",
       "      <td>262</td>\n",
       "      <td>31</td>\n",
       "      <td>0.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4983</th>\n",
       "      <td>4912</td>\n",
       "      <td>8</td>\n",
       "      <td>0.40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4985</th>\n",
       "      <td>4914</td>\n",
       "      <td>38</td>\n",
       "      <td>0.60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4986</th>\n",
       "      <td>4914</td>\n",
       "      <td>43</td>\n",
       "      <td>0.40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5033</th>\n",
       "      <td>4961</td>\n",
       "      <td>53</td>\n",
       "      <td>0.75</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5034</th>\n",
       "      <td>4961</td>\n",
       "      <td>15</td>\n",
       "      <td>0.25</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>138 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      family_id  day  result\n",
       "59           59   38    0.25\n",
       "60           59   14    0.75\n",
       "241         240   32    0.75\n",
       "242         240   56    0.25\n",
       "264         262   31    0.50\n",
       "...         ...  ...     ...\n",
       "4983       4912    8    0.40\n",
       "4985       4914   38    0.60\n",
       "4986       4914   43    0.40\n",
       "5033       4961   53    0.75\n",
       "5034       4961   15    0.25\n",
       "\n",
       "[138 rows x 3 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows with fractional values (neither ~0 nor ~1): families the LP did not assign integrally\n",
    "unassigned_df = result[(result.result < THRS) & (result.result > 1 - THRS)]\n",
    "unassigned_df"
   ]
  },
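  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two filters above can be wrapped in a small helper that also returns the list of families left for the MIP step. A hedged sketch: `split_lp_solution` is a hypothetical name, and it treats any family without a near-1 row as unassigned, the same rule as `THRS` above. It is demonstrated on a toy frame with the same columns as `result`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def split_lp_solution(result, thrs=0.999):\n",
    "    # Hypothetical helper. Rows with value above thrs are fixed assignments\n",
    "    assigned = result[result['result'] > thrs].copy()\n",
    "    # Families without any near-1 row are left for the MIP\n",
    "    remaining = sorted(int(f) for f in set(result['family_id']) - set(assigned['family_id']))\n",
    "    return assigned, remaining\n",
    "\n",
    "demo = pd.DataFrame({'family_id': [0, 1, 1, 2],\n",
    "                     'day': [5, 3, 7, 9],\n",
    "                     'result': [1.0, 0.25, 0.75, 1.0]})\n",
    "assigned, remaining = split_lp_solution(demo)\n",
    "print(remaining)  # [1]: family 1 is split 0.25 / 0.75"
   ]
  },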
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>family_id</th>\n",
       "      <th>day</th>\n",
       "      <th>result</th>\n",
       "      <th>family_size</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>51</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>25</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>99</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>52</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5068</th>\n",
       "      <td>4995</td>\n",
       "      <td>15</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5069</th>\n",
       "      <td>4996</td>\n",
       "      <td>87</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5070</th>\n",
       "      <td>4997</td>\n",
       "      <td>31</td>\n",
       "      <td>1.0</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5071</th>\n",
       "      <td>4998</td>\n",
       "      <td>91</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5072</th>\n",
       "      <td>4999</td>\n",
       "      <td>12</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4931 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      family_id  day  result  family_size\n",
       "0             0   51     1.0            4\n",
       "1             1   25     1.0            4\n",
       "2             2   99     1.0            3\n",
       "3             3    1     1.0            2\n",
       "4             4   52     1.0            4\n",
       "...         ...  ...     ...          ...\n",
       "5068       4995   15     1.0            4\n",
       "5069       4996   87     1.0            2\n",
       "5070       4997   31     1.0            6\n",
       "5071       4998   91     1.0            5\n",
       "5072       4999   12     1.0            4\n",
       "\n",
       "[4931 rows x 4 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Attach each family's size so that daily occupancy can be summed from assigned_df\n",
    "assigned_df['family_size'] = FAMILY_SIZE[assigned_df.family_id]\n",
    "assigned_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([290, 271, 294, 293, 263, 242, 223, 247, 273, 297, 288, 292, 275,\n",
       "       250, 238, 272, 292, 292, 271, 248, 223, 244, 264, 291, 292, 296,\n",
       "       273, 249, 234, 251, 278, 283, 252, 235, 205, 184, 202, 233, 253,\n",
       "       231, 210, 183, 204, 229, 247, 281, 256, 223, 204, 198, 222, 248,\n",
       "       255, 223, 208, 185, 173, 196, 219, 198, 174, 141, 124, 121, 149,\n",
       "       170, 178, 158, 136, 123, 125, 135, 158, 185, 158, 135, 121, 125,\n",
       "       142, 167, 186, 167, 138, 120, 128, 155, 174, 203, 177, 152, 130,\n",
       "       122, 130, 158, 178, 158, 128, 125, 122, 124], dtype=int64)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
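    "# Sum family sizes per day; reindex so that days without assigned families count as 0\n",
    "occupancy = assigned_df.groupby('day').family_size.sum().reindex(range(N_DAYS), fill_value=0).values\n",
    "occupancy  # visitors per day index 0-99 from the 4931 integrally assigned families"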
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 4, 0, 0,\n",
       "       0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0,\n",
       "       0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 1], dtype=int64)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Visitors still needed on each day to reach MIN_OCCUPANCY\n",
    "min_occupancy = np.array([max(0, MIN_OCCUPANCY - x) for x in occupancy])\n",
    "min_occupancy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 10,  29,   6,   7,  37,  58,  77,  53,  27,   3,  12,   8,  25,\n",
       "        50,  62,  28,   8,   8,  29,  52,  77,  56,  36,   9,   8,   4,\n",
       "        27,  51,  66,  49,  22,  17,  48,  65,  95, 116,  98,  67,  47,\n",
       "        69,  90, 117,  96,  71,  53,  19,  44,  77,  96, 102,  78,  52,\n",
       "        45,  77,  92, 115, 127, 104,  81, 102, 126, 159, 176, 179, 151,\n",
       "       130, 122, 142, 164, 177, 175, 165, 142, 115, 142, 165, 179, 175,\n",
       "       158, 133, 114, 133, 162, 180, 172, 145, 126,  97, 123, 148, 170,\n",
       "       178, 170, 142, 122, 142, 172, 175, 178, 176], dtype=int64)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Remaining capacity on each day before reaching MAX_OCCUPANCY\n",
    "max_occupancy = np.array([MAX_OCCUPANCY - x for x in occupancy])\n",
    "max_occupancy"
   ]
  },
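  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two list comprehensions above compute the residual bounds that the MIP has to satisfy. An equivalent vectorized sketch with a hypothetical helper name, assuming `occupancy` has one entry per day index:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def residual_bounds(occupancy, min_occ=125, max_occ=300):\n",
    "    # Hypothetical helper mirroring min_occupancy / max_occupancy above\n",
    "    occ = np.asarray(occupancy)\n",
    "    lower = np.clip(min_occ - occ, 0, None)  # visitors still needed per day\n",
    "    upper = max_occ - occ  # capacity still free per day\n",
    "    return lower, upper\n",
    "\n",
    "lower, upper = residual_bounds([120, 200, 300])\n",
    "print(lower, upper)"
   ]
  },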
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve the remaining families with mixed-integer programming\n",
    "def solveIP(families, min_occupancy, max_occupancy):\n",
    "    # CBC solves the MIP; x[i, j] are true binary variables here\n",
    "    solver = pywraplp.Solver('AssignmentProblem', pywraplp.Solver.CBC_MIXED_INTEGER_PROGRAMMING)\n",
    "    # families: the family_ids still to be scheduled\n",
    "    n_families = len(families)\n",
    "    \n",
    "    x = {}  # x[i, j] = 1 if family i visits on day index j\n",
    "    # candidates[j]: remaining families that list day j among their choices\n",
    "    candidates = [[] for _ in range(N_DAYS)]\n",
    "    for i in families:\n",
    "        for j in DESIRED[i, :]:\n",
    "            candidates[j].append(i)\n",
    "            x[i, j] = solver.BoolVar('x[%i,%i]' % (i, j))\n",
    "    \n",
    "    # Visitors contributed by the remaining families on each day index j (0-99)\n",
    "    daily_occupancy = [solver.Sum([x[i, j] * FAMILY_SIZE[i] for i in candidates[j]]) for j in range(N_DAYS)]\n",
    "    \n",
    "    # How many times each family is scheduled across its 10 choices\n",
    "    family_presence = [solver.Sum(x[i, j] for j in DESIRED[i, :]) for i in families]\n",
    "    \n",
    "    # Objective: minimize the preference cost\n",
    "    preference_cost = solver.Sum([pcost_mat[i, j] * x[i, j] for i in families for j in DESIRED[i, :]])\n",
    "    solver.Minimize(preference_cost)\n",
    "    \n",
    "    # Extra constraint we impose: at most 25 change between adjacent days.\n",
    "    # Note that it only applies to the remaining families' visitors, not to the total.\n",
    "    for j in range(N_DAYS - 1):\n",
    "        solver.Add(daily_occupancy[j] - daily_occupancy[j+1] <= 25)\n",
    "        solver.Add(daily_occupancy[j+1] - daily_occupancy[j] <= 25)\n",
    "    \n",
    "    # Each family is scheduled exactly once among its 10 choices\n",
    "    for k in range(n_families):\n",
    "        solver.Add(family_presence[k] == 1)\n",
    "    # Residual daily bounds left over by the LP assignment\n",
    "    for j in range(N_DAYS):\n",
    "        solver.Add(daily_occupancy[j] >= min_occupancy[j])\n",
    "        solver.Add(daily_occupancy[j] <= max_occupancy[j])\n",
    "    status = solver.Solve()\n",
    "    assert status in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE), 'MIP found no feasible solution'\n",
    "    \n",
    "    # Binary values: compare against 0.5 to be robust to solver tolerances\n",
    "    temp = [(i, j) for i in families for j in DESIRED[i, :] if x[i, j].solution_value() > 0.5]\n",
    "    \n",
    "    # Visit days for the remaining families\n",
    "    df = pd.DataFrame(temp, columns=['family_id', 'day'])\n",
    "    return df\n"
   ]
  },
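  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 3 of the plan combines the LP and MIP results into one schedule. A sketch of that merge with basic feasibility checks, assuming both inputs have `family_id` and `day` (day index) columns as returned above; the helper name and the `assigned_day` output column are assumptions, with `assigned_day` converting the index back to days before Christmas (1-100). It is demonstrated on toy data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def merge_and_check(lp_assigned, mip_assigned, family_size,\n",
    "                    n_days=100, min_occ=125, max_occ=300):\n",
    "    # Hypothetical helper: combine the two partial schedules\n",
    "    sched = pd.concat([lp_assigned[['family_id', 'day']],\n",
    "                       mip_assigned[['family_id', 'day']]], ignore_index=True)\n",
    "    sched = sched.sort_values('family_id').reset_index(drop=True)\n",
    "    # Every family must be scheduled exactly once\n",
    "    if not sched['family_id'].is_unique or len(sched) != len(family_size):\n",
    "        raise ValueError('each family must appear exactly once')\n",
    "    # Daily occupancy per day index, weighted by family size\n",
    "    sizes = np.asarray(family_size)[sched['family_id'].to_numpy()]\n",
    "    occ = np.bincount(sched['day'].to_numpy(), weights=sizes, minlength=n_days)\n",
    "    ok = bool(occ.min() >= min_occ and occ.max() <= max_occ)\n",
    "    # Convert the day index back to days before Christmas (column name assumed)\n",
    "    sched['assigned_day'] = sched['day'] + 1\n",
    "    return sched, occ, ok\n",
    "\n",
    "lp = pd.DataFrame({'family_id': [0, 2], 'day': [0, 1]})\n",
    "mip = pd.DataFrame({'family_id': [1], 'day': [1]})\n",
    "sched, occ, ok = merge_and_check(lp, mip, [2, 3, 4], n_days=2, min_occ=1, max_occ=10)\n",
    "print(occ, ok)"
   ]
  },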
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>family_id</th>\n",
       "      <th>day</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>59</td>\n",
       "      <td>38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>240</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>262</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>357</td>\n",
       "      <td>24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>488</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>4869</td>\n",
       "      <td>59</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65</th>\n",
       "      <td>4886</td>\n",
       "      <td>98</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>4912</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>67</th>\n",
       "      <td>4914</td>\n",
       "      <td>38</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>68</th>\n",
       "      <td>4961</td>\n",
       "      <td>53</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>69 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    family_id  day\n",
       "0          59   38\n",
       "1         240   32\n",
       "2         262   31\n",
       "3         357   24\n",
       "4         488   39\n",
       "..        ...  ...\n",
       "64       4869   59\n",
       "65       4886   98\n",
       "66       4912   17\n",
       "67       4914   38\n",
       "68       4961   53\n",
       "\n",
       "[69 rows x 2 columns]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# family_ids that still have no assigned day\n",
    "unassigned = unassigned_df.family_id.unique()\n",
    "result = solveIP(unassigned, min_occupancy, max_occupancy)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>family_id</th>\n",
       "      <th>day</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>25</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>99</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5068</th>\n",
       "      <td>4995</td>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5069</th>\n",
       "      <td>4996</td>\n",
       "      <td>87</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5070</th>\n",
       "      <td>4997</td>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5071</th>\n",
       "      <td>4998</td>\n",
       "      <td>91</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5072</th>\n",
       "      <td>4999</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5000 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      family_id  day\n",
       "0             0   51\n",
       "1             1   25\n",
       "2             2   99\n",
       "3             3    1\n",
       "4             4   52\n",
       "...         ...  ...\n",
       "5068       4995   15\n",
       "5069       4996   87\n",
       "5070       4997   31\n",
       "5071       4998   91\n",
       "5072       4999   12\n",
       "\n",
       "[5000 rows x 2 columns]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.concat((assigned_df[['family_id','day']], result)).sort_values('family_id')\n",
    "df"
   ]
  },
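  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before scoring, it is worth checking that the merged table is a valid assignment: every `family_id` from 0 to `N_FAMILY - 1` appears exactly once and every 0-based day index lies in `[0, N_DAYS)`. The helper below, `check_assignment`, is a hypothetical name introduced here, and the demo uses tiny made-up tables rather than the real `df`. On the notebook's data it would be called as `check_assignment(df, N_FAMILY)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def check_assignment(assign, n_family, n_days=100):\n",
    "    # Hypothetical helper: True if every family 0..n_family-1 appears exactly once\n",
    "    # and every 0-based day index lies in [0, n_days)\n",
    "    ids = assign['family_id']\n",
    "    one_row_each = ids.is_unique and len(assign) == n_family\n",
    "    covers_all = set(ids) == set(range(n_family))\n",
    "    days_ok = assign['day'].between(0, n_days - 1).all()\n",
    "    return bool(one_row_each and covers_all and days_ok)\n",
    "\n",
    "# Made-up toy tables, not the real df\n",
    "toy = pd.DataFrame({'family_id': [2, 0, 1], 'day': [5, 99, 0]})\n",
    "dup = pd.DataFrame({'family_id': [0, 0, 1], 'day': [5, 6, 7]})\n",
    "print(check_assignment(toy, n_family=3))  # True\n",
    "print(check_assignment(dup, n_family=3))  # False: family 0 appears twice"
   ]
  },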
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numba import njit\n",
    "# Given an assignment, compute its preference cost and the daily occupancy\n",
    "@njit(fastmath=True)\n",
    "def pcost(prediction):\n",
    "    daily_occupancy = np.zeros(N_DAYS+1, dtype=np.int64)\n",
    "    penalty = 0\n",
    "    for (i,p) in enumerate(prediction):\n",
    "        # Size of family i\n",
    "        n = FAMILY_SIZE[i]\n",
    "        # Preference cost of family i visiting on day p\n",
    "        penalty += pcost_mat[i, p]\n",
    "        # Add the family to day p's occupancy\n",
    "        daily_occupancy[p] += n\n",
    "    return penalty, daily_occupancy\n",
    "# Given the daily occupancy, compute the accounting penalty\n",
    "@njit(fastmath=True)\n",
    "def acost(daily_occupancy):\n",
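    "    accounting_cost = 0\n",
    "    num_out_of_range = 0\n",
    "    # Scoring rule N(101) = N(100): the slot after the last day (index N_DAYS) copies the last day's occupancy\n",
    "    daily_occupancy[N_DAYS] = daily_occupancy[N_DAYS - 1]\n",
    "    for day in range(N_DAYS):\n",
    "        n_p1 = daily_occupancy[day+1]  # N(d+1): the previous calendar day\n",
    "        n = daily_occupancy[day]  # N(d): the current day\n",
    "        # Count days whose occupancy is outside the allowed range\n",
    "        num_out_of_range += (n>MAX_OCCUPANCY) or (n<MIN_OCCUPANCY)\n",
    "        # Accounting penalty for day d\n",
    "        accounting_cost += acost_mat[n, n_p1]\n",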
    "    return accounting_cost, num_out_of_range\n",
    "\n",
    "# Total score of an assignment `prediction`\n",
    "@njit(fastmath=True)\n",
    "def cost_function(prediction):\n",
    "    # Compute the preference cost and the accounting penalty for this assignment\n",
    "    penalty, daily_occupancy = pcost(prediction)  # preference cost and daily occupancy\n",
    "    accounting_cost, num_out_of_range = acost(daily_occupancy)  # accounting penalty from the daily occupancy\n",
    "    final_score = penalty + accounting_cost + num_out_of_range * 99999999  # huge penalty for out-of-range days\n",
    "    return final_score\n",
    "    "
   ]
  },
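  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`acost` only looks values up in `acost_mat`, which is built earlier in the notebook. Assuming that matrix tabulates the accounting penalty of the Kaggle Santa 2019 competition, the per-day term is $\\frac{N_d - 125}{400} N_d^{\\,1/2 + |N_d - N_{d+1}|/50}$, with $N_{101} = N_{100}$ for the last day. The sketch below (`accounting_penalty` is a hypothetical name) evaluates that sum directly from an occupancy vector. It shows why a large jump in occupancy between neighbouring days is so expensive: the jump enters the exponent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def accounting_penalty(occupancy):\n",
    "    # Hypothetical helper assuming the Santa 2019 formula:\n",
    "    # sum_d (N_d - 125) / 400 * N_d ** (0.5 + |N_d - N_{d+1}| / 50), with N_101 = N_100\n",
    "    # occupancy[0] is 1 day before Christmas, occupancy[-1] the farthest day\n",
    "    occ = np.asarray(occupancy, dtype=np.float64)\n",
    "    nxt = np.append(occ[1:], occ[-1])  # N_{d+1}; the last day is compared with itself\n",
    "    return float(np.sum((occ - 125.0) / 400.0 * occ ** (0.5 + np.abs(occ - nxt) / 50.0)))\n",
    "\n",
    "print(accounting_penalty([125] * 100))  # 0.0: days at the 125 floor cost nothing\n",
    "print(accounting_penalty([300] * 100))  # steady 300 per day: about 7.58 per day\n",
    "print(accounting_penalty([300] + [125] * 99))  # one jump of 175: about 3.5e9"
   ]
  },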
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "109274.06836638087"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction = df.day.values\n",
    "cost_function(prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "submission1.csvsaved\n"
     ]
    }
   ],
   "source": [
    "def save_result(pred, filename):\n",
    "    result = pd.DataFrame(range(N_FAMILY),columns=['family_id'])\n",
    "    result['assigned_day'] = pred + 1  # day indices are 0-based; submission days are 1-based\n",
    "    result.to_csv(filename, index=False)\n",
    "    print(filename + ' saved')\n",
    "    return result\n",
    "result = save_result(prediction, 'submission1.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Searching for a better solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2070 4520 4188 ... 1081 1099 3186]\n",
      "74036.63763164761\n"
     ]
    }
   ],
   "source": [
    "def find_better(pred):\n",
    "    # Visit families from smallest to largest\n",
    "    fobs = np.argsort(FAMILY_SIZE)\n",
    "    print(fobs)\n",
    "    score = cost_function(pred)\n",
    "    original_score = np.inf\n",
    "    \n",
    "    # Repeat full passes until a pass brings no improvement\n",
    "    while score<original_score:\n",
    "        original_score = score\n",
    "        for family_id in fobs:\n",
    "            for pick in range(10):\n",
    "                # Day of this family's choice number `pick`\n",
    "                day = DESIRED[family_id, pick]\n",
    "                # The family's current day\n",
    "                oldvalue = pred[family_id]\n",
    "                # Tentatively move the family to the new day\n",
    "                pred[family_id] = day\n",
    "                # Re-score the whole assignment\n",
    "                new_score = cost_function(pred)\n",
    "                # Keep the move if it lowers the score\n",
    "                if new_score < score:\n",
    "                    score = new_score\n",
    "                else:  # otherwise restore the old day\n",
    "                    pred[family_id] = oldvalue\n",
    "        print(score,end='\\r')\n",
    "    print(score)\n",
    "new = prediction.copy()\n",
    "find_better(new)"
   ]
  },
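  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`find_better` is a greedy hill-climbing (first-improvement local search) loop. It tries each family's 10 choices one at a time, keeps a move only if the total score drops, and repeats full passes until a pass brings no improvement, so the result is a local optimum with respect to single-family moves. The sketch below runs the same loop on a made-up instance with 4 families and 3 days. It uses a hypothetical toy cost (preference cost plus squared deviation from an even daily load) in place of `cost_function`; all `toy_*` names are illustrative only. Re-scoring the whole assignment for every trial move keeps the code simple, at the price of speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up toy instance (hypothetical): 4 families, 3 days\n",
    "toy_pref = np.array([[0, 5, 9],   # toy_pref[i, d]: preference cost of family i on day d\n",
    "                     [0, 4, 8],\n",
    "                     [0, 3, 7],\n",
    "                     [6, 0, 2]])\n",
    "toy_size = np.array([4, 3, 2, 5])\n",
    "toy_target = toy_size.sum() / 3  # ideal, perfectly even load per day\n",
    "\n",
    "def toy_cost(pred):\n",
    "    # Stand-in for cost_function: preference cost + squared deviation from an even load\n",
    "    load = np.bincount(pred, weights=toy_size, minlength=3)\n",
    "    return toy_pref[np.arange(len(pred)), pred].sum() + ((load - toy_target) ** 2).sum()\n",
    "\n",
    "def hill_climb(pred):\n",
    "    # Same greedy scheme as find_better; mutates pred in place and returns the final score\n",
    "    score, prev = toy_cost(pred), np.inf\n",
    "    while score < prev:  # stop after a full pass with no improvement\n",
    "        prev = score\n",
    "        for i in np.argsort(toy_size):  # smallest families first\n",
    "            for d in range(3):\n",
    "                old = pred[i]\n",
    "                pred[i] = d\n",
    "                candidate = toy_cost(pred)\n",
    "                if candidate < score:\n",
    "                    score = candidate  # keep the move\n",
    "                else:\n",
    "                    pred[i] = old  # undo the move\n",
    "    return score\n",
    "\n",
    "toy_pred = np.zeros(4, dtype=np.int64)  # start with every family on day 0\n",
    "toy_start = toy_cost(toy_pred)\n",
    "toy_final = hill_climb(toy_pred)\n",
    "print(toy_start, toy_final, toy_pred)"
   ]
  },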
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "submission2.csvsaved\n"
     ]
    }
   ],
   "source": [
    "# cost_function(new)\n",
    "result = save_result(new,'submission2.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.0 64-bit ('Bi_env': venv)",
   "language": "python",
   "name": "python38064bitbienvvenvba07af95a1bb4b078aa8134bba84dff2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "toc-autonumbering": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
